import random
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# import torch.nn.functional as F # Not explicitly used in the provided snippet now
import json
from tqdm import tqdm
from collections import Counter
import argparse
import math
import os
import torch.nn as nn
import csv # For CSV output

class RNNClassifier(nn.Module):
    """Recurrent binary classifier over padded sequences of scalar features.

    A single-layer GRU (default) or LSTM consumes the sequence; the hidden
    state of the last layer is passed through a two-layer MLP and a sigmoid.

    Args:
        hidden_dim: Width of the recurrent state and of the MLP hidden layer.
        rnn_type: ``'LSTM'`` (case-insensitive) selects an LSTM; any other
            value selects a GRU.
    """

    def __init__(self, hidden_dim=256, rnn_type='GRU'):
        super().__init__()
        self.rnn_type = rnn_type.upper()
        recurrent_cls = nn.LSTM if self.rnn_type == 'LSTM' else nn.GRU
        self.rnn = recurrent_cls(input_size=1, hidden_size=hidden_dim, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, lengths):
        """Return one probability per sequence.

        Args:
            x: Padded batch-first input; the last dimension must be 1 to match
                the RNN's ``input_size``.
            lengths: Number of valid steps per sequence (moved to CPU before
                packing).

        Returns:
            Tensor of sigmoid outputs with the trailing singleton dim removed.
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, state = self.rnn(packed)
        # nn.LSTM returns (h_n, c_n); nn.GRU returns h_n alone.
        hidden = state[0] if self.rnn_type == 'LSTM' else state
        logits = self.classifier(hidden[-1])
        return torch.sigmoid(logits).squeeze(-1)


class DeepSetClassifier(nn.Module):
    """Permutation-invariant binary classifier over padded sets of scalars.

    Each element is embedded by ``phi``, the embeddings of the valid elements
    are mean-pooled, and ``rho`` maps the pooled vector to a single logit that
    is passed through a sigmoid.

    Args:
        hidden_dim: Width of the element embedding and of both MLPs.
    """

    def __init__(self, hidden_dim=512):
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, lengths):
        """Return one probability per set.

        Args:
            x: Padded input whose last dimension is 1 (``phi`` starts with
                ``nn.Linear(1, ...)``); dim 1 indexes set elements.
            lengths: Number of valid elements per row, on the same device as
                ``x``. Positions at or beyond the length are ignored.

        Returns:
            Tensor of sigmoid outputs with the trailing singleton dim removed.
        """
        phi_x = self.phi(x)
        # Build the position index directly on x's device instead of creating
        # it on CPU and copying it over.
        positions = torch.arange(x.size(1), device=x.device)
        mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)
        phi_x = phi_x * mask
        # Clamp the divisor so an empty set pools to the zero vector instead of
        # producing 0/0 = NaN; rows with length >= 1 are unaffected.
        denom = lengths.clamp(min=1).unsqueeze(-1)
        agg = phi_x.sum(dim=1) / denom
        out = self.rho(agg)
        return torch.sigmoid(out).squeeze(-1)

# Global verifier, initialized once by the `__main__` block below and read by
# `_check_sufficiency_with_verifier`. Despite the annotation, the script may
# store either a DeepSetClassifier or an RNNClassifier here; it stays None
# until the script entry point assigns it.
verifier_model_global: DeepSetClassifier = None
# Single device string used for every model and tensor in this file.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Heads import (ensure this path is correct in your environment)
from heads import get_matching_head

# Seed every RNG this file draws from (Python, NumPy, torch CPU and, when
# available, all CUDA devices). This runs at import time.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


class MatchingInference:
    """Score text pairs with a sentence-embedding model plus a matching head.

    The constructor reads the following from ``model_dir``:
      * ``embedding_model``             - a SentenceTransformer checkpoint,
      * ``matching_head.pt``            - state dict for the ``"cos_sim"`` head
                                          returned by ``get_matching_head``,
      * ``tokenid_embedding_cache.pt``  - per-token-id embeddings; created on
                                          first use when absent.

    Attributes:
        embedding_model: SentenceTransformer in eval mode on ``DEVICE``.
        matching_head: Head module in eval mode on ``DEVICE``.
        tokenid2emb: Dict mapping an int token id to that token's embedding.
    """

    def __init__(self, model_dir):
        # NOTE(review): trust_remote_code=True lets the checkpoint supply its
        # own model code, so only point this at trusted model directories.
        self.embedding_model = SentenceTransformer(f"{model_dir}/embedding_model", trust_remote_code=True, device=DEVICE)
        self.embedding_model.eval()

        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.matching_head = get_matching_head("cos_sim", embedding_dim)
        self.matching_head.load_state_dict(torch.load(f"{model_dir}/matching_head.pt", map_location=DEVICE))
        self.matching_head = self.matching_head.to(DEVICE)
        self.matching_head.eval()

        self.tokenid2emb = self._build_token_embedding_cache(model_dir)

    def _build_token_embedding_cache(self, model_dir):
        """Return ``{token_id: embedding}``, loading it from disk when cached.

        When ``tokenid_embedding_cache.pt`` is missing, every vocabulary token
        that is non-blank and does not start with ``"["`` is encoded as a
        standalone string, and the resulting mapping is saved to that path
        before being returned.
        """
        cache_path = os.path.join(model_dir, "tokenid_embedding_cache.pt")
        if os.path.exists(cache_path):
            print(f"📦 Loading token ID embedding cache from {cache_path} ...")
            tokenid2emb_raw = torch.load(cache_path, map_location=DEVICE)
            # Keys are normalized to int; .to(DEVICE) is a no-op after
            # map_location but keeps the contract explicit.
            return {int(token_id): emb.to(DEVICE) for token_id, emb in tokenid2emb_raw.items()}

        print("⚙️ Building token ID embedding index from tokenizer vocab...")
        vocab = self.embedding_model.tokenizer.get_vocab()
        # Drop bracket-prefixed vocabulary entries and whitespace-only tokens.
        kept_items = [(tok, idx) for tok, idx in vocab.items() if not tok.startswith("[") and tok.strip()]
        tokens = [tok for tok, _ in kept_items]
        ids = [idx for _, idx in kept_items]

        token_embs = self.embedding_model.encode(tokens, convert_to_tensor=True, show_progress_bar=True, device=DEVICE)
        tokenid2emb = {int(i): emb for i, emb in zip(ids, token_embs)}

        # Tensors are saved on whatever device encode() returned them on; the
        # load path above remaps them with map_location.
        torch.save(tokenid2emb, cache_path)
        print(f"✅ Cached embeddings for {len(tokenid2emb)} token ids to {cache_path}")
        return {int(k): v.to(DEVICE) for k, v in tokenid2emb.items()}

    def encode(self, text: str) -> torch.Tensor:
        """Embed ``text`` with the sentence model and return a tensor on ``DEVICE``."""
        return self.embedding_model.encode(text, convert_to_tensor=True, device=DEVICE)

    def score(self, emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
        """Return the sigmoid of the head's logit for one pair of embeddings.

        Both embeddings are moved to ``DEVICE`` and given a leading batch
        dimension of 1 before being passed to the matching head.
        """
        emb_a = emb_a.to(DEVICE)
        emb_b = emb_b.to(DEVICE)
        features = {
            "embedding_a": emb_a.unsqueeze(0),
            "embedding_b": emb_b.unsqueeze(0)
        }
        with torch.no_grad():
            logits = self.matching_head(features)["logits"]
            return torch.sigmoid(logits).item()

    @torch.no_grad()
    def predict_batch(self, answers, reasons, batch_size=32):
        """Return one match probability per ``(answer, reason)`` pair.

        Args:
            answers: Texts encoded as ``embedding_a``.
            reasons: Texts encoded as ``embedding_b``; same length as ``answers``.
            batch_size: Number of pairs encoded and scored per step.

        Returns:
            List of probabilities, in input order.
        """
        assert len(answers) == len(reasons)
        all_probs = []
        for idx in range(0, len(answers), batch_size):
            batch_answers = answers[idx:idx+batch_size]
            batch_reasons = reasons[idx:idx+batch_size]

            emb_a = self.embedding_model.encode(batch_answers, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)
            emb_b = self.embedding_model.encode(batch_reasons, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)

            features = {"embedding_a": emb_a, "embedding_b": emb_b}
            outputs = self.matching_head(features)
            # Give the logits an explicit batch axis before dropping the
            # trailing one. A bare squeeze(-1) collapses 1-D logits of shape
            # (1,) to a 0-d tensor, whose tolist() is a float and cannot be
            # passed to extend(); score_pair always sends a batch of one.
            logits = outputs["logits"].reshape(len(batch_answers), -1).squeeze(-1)
            probs = torch.sigmoid(logits)
            all_probs.extend(probs.tolist())
        return all_probs

    def score_pair(self, reason: str, answer: str) -> float:
        """Return the match probability for a single ``(reason, answer)`` pair.

        Thin wrapper over ``predict_batch`` with a batch of one; prefer
        ``predict_batch`` directly when scoring many pairs.
        """
        probs = self.predict_batch([answer], [reason], batch_size=1)
        return probs[0]


def make_model_a_scorer(infer_a: MatchingInference):
    """Build the model-A block scorer.

    Returns a ``(scorer, tokenizer)`` tuple. ``scorer(token_ids, sentence_text)``
    averages the cached embeddings of the given token ids, encodes
    ``sentence_text`` with ``infer_a``, and returns the sigmoid of the matching
    head's logit for that pair. ``tokenizer`` is the embedding model's tokenizer.
    """
    def model_a_score(token_ids: List[int], sentence_text: str) -> float:
        cache = infer_a.tokenid2emb
        known_embs = [cache[tid] for tid in token_ids if tid in cache]
        if len(known_embs) == 0:
            # None of the sampled ids has a cached embedding.
            return 0.0

        pooled_tokens = torch.stack(known_embs).mean(dim=0)
        sentence_vec = infer_a.encode(sentence_text)

        head_input = {
            "embedding_a": pooled_tokens.to(DEVICE).unsqueeze(0),
            "embedding_b": sentence_vec.to(DEVICE).unsqueeze(0),
        }
        with torch.no_grad():
            head_logits = infer_a.matching_head(head_input)["logits"]
            return torch.sigmoid(head_logits).item()

    return model_a_score, infer_a.embedding_model.tokenizer


def _check_sufficiency_with_verifier(
    current_scores_a: List[float],
    current_scores_b: List[float],
    verification_threshold: float
) -> Tuple[bool, float]:
    """Ask the global verifier whether the accumulated scores are sufficient.

    The two score lists are interleaved as ``a0, b0, a1, b1, ...`` and fed to
    ``verifier_model_global`` as a single sequence with one scalar feature per
    step.

    Args:
        current_scores_a: Model-A scores accumulated so far.
        current_scores_b: Model-B scores accumulated so far; same length as
            ``current_scores_a``.
        verification_threshold: Minimum verifier probability that counts as
            sufficient.

    Returns:
        ``(is_sufficient, model_probability)``. Returns ``(False, 0.0)``
        without running the verifier when either list is empty or the lengths
        differ (the latter also prints a warning).

    Raises:
        ValueError: If ``verifier_model_global`` has not been initialized.
    """
    if verifier_model_global is None:
        raise ValueError("Verifier model not initialized globally.")

    if not current_scores_a or not current_scores_b:
        return False, 0.0

    if len(current_scores_a) != len(current_scores_b):
        # Should not happen when the caller appends to both lists in lockstep.
        print("Warning: Mismatch in score list lengths for verifier.")
        return False, 0.0

    # Both lists are non-empty and equally long here, so the interleaved
    # sequence is guaranteed to be non-empty.
    interleaved_scores = [
        score
        for score_a, score_b in zip(current_scores_a, current_scores_b)
        for score in (score_a, score_b)
    ]

    x = torch.tensor(interleaved_scores, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(DEVICE)
    # lengths lives on DEVICE alongside x: DeepSetClassifier compares it with
    # an index built on x's device, and RNNClassifier moves it to CPU itself
    # before packing.
    lengths = torch.tensor([len(interleaved_scores)], dtype=torch.long).to(DEVICE)

    verifier_model_global.eval()
    with torch.no_grad():
        model_probability = verifier_model_global(x, lengths).item()

    sufficient = model_probability >= verification_threshold
    return sufficient, model_probability


def preprocess_samples_for_efficient_validation(
    data: List[Dict],
    infer_a: MatchingInference,
    infer_b: MatchingInference,
    token_ratio: float,
) -> List[Dict]:
    """Compute per-block (score_a, score_b) pairs once for every sample.

    Args:
        data: Samples; each must provide ``"R"`` (list of block texts) and
            ``"A"`` (answer text), and may provide ``"label"``.
        infer_a: Engine whose tokenizer and token-embedding cache produce
            ``score_a`` from a random subset of each block's token ids.
        infer_b: Engine that produces ``score_b`` via ``score_pair(block, A)``.
        token_ratio: Fraction of a block's tokens to sample for ``score_a``
            (at least one token is always sampled).

    Returns:
        One dict per input sample with keys:
          * ``'label'``: the sample's label, or ``None`` when absent;
          * ``'block_scores'``: list of ``{'a': score_a, 'b': score_b}`` for
            every block that was scored successfully, in shuffled order;
          * ``'alpha'``: number of non-blank blocks. This can exceed
            ``len(block_scores)`` when a block has no tokens or scoring fails.
    """
    print("🚀 Pre-processing samples to compute block scores...")
    preprocessed_data = []
    model_a_scorer, tokenizer_a = make_model_a_scorer(infer_a)

    for item in tqdm(data, desc="Pre-calculating block scores"):
        # Only "R" and "A" are read; other keys (apart from the optional
        # "label") are not required.
        R_sentences, A_text = item["R"], item["A"]

        filtered_sentences = [s for s in R_sentences if s and s.strip()]
        if not filtered_sentences:
            preprocessed_data.append({
                "label": item.get("label"),
                "block_scores": [],
                "alpha": 0
            })
            continue

        # Shuffle once here so every (threshold, probing ratio) combination
        # later sees the blocks in the same order.
        random.shuffle(filtered_sentences)

        current_sample_block_scores = []
        for block_text in filtered_sentences:
            encoding = tokenizer_a(block_text, add_special_tokens=False, return_tensors='pt')
            block_token_ids = encoding["input_ids"][0].tolist()

            if not block_token_ids:
                # Nothing to sample from; the block contributes no score.
                continue

            sample_size = max(1, int(len(block_token_ids) * token_ratio))
            sample_size = min(sample_size, len(block_token_ids))
            selected_ids = random.sample(block_token_ids, sample_size)

            try:
                score_a = model_a_scorer(selected_ids, block_text)
                score_b = infer_b.score_pair(block_text, A_text)
                current_sample_block_scores.append({'a': score_a, 'b': score_b})
            except Exception as e:
                # Best effort: report the failure and keep going with the
                # remaining blocks rather than aborting the whole run.
                print(f"Error scoring block '{block_text[:30]}...': {e}. Skipping block.")
                continue

        preprocessed_data.append({
            "label": item.get("label"),
            "block_scores": current_sample_block_scores,
            "alpha": len(filtered_sentences)
        })
    return preprocessed_data


def run_validation_on_preprocessed(
    preprocessed_sample_data: Dict,
    probing_ratio: float,
    verification_threshold: float
) -> bool:
    """Predict sufficiency for one sample from its pre-computed block scores.

    Scores are accumulated block by block. Once at least
    ``min(alpha, ceil(max(1, probing_ratio * alpha)))`` blocks have been
    accumulated, the verifier is consulted after every further block and the
    function returns ``True`` as soon as it reports sufficiency.

    Args:
        preprocessed_sample_data: Dict with ``"block_scores"`` (list of
            ``{'a': ..., 'b': ...}``) and ``"alpha"`` (number of candidate
            blocks before scoring failures).
        probing_ratio: Fraction of ``alpha`` to accumulate before the first
            verifier check.
        verification_threshold: Probability threshold passed to the verifier.

    Returns:
        ``True`` if the verifier judged some prefix (or all) of the scores
        sufficient, ``False`` otherwise, including when there are no scores.
    """
    block_scores = preprocessed_sample_data["block_scores"]
    alpha = preprocessed_sample_data["alpha"]

    if not block_scores:
        return False

    # Derived from alpha (all candidate blocks), which may exceed the number
    # of blocks that were actually scored.
    min_blocks_before_check = min(alpha, math.ceil(max(1, probing_ratio * alpha)))

    accumulated_scores_a = []
    accumulated_scores_b = []

    for processed, score_pair in enumerate(block_scores, start=1):
        accumulated_scores_a.append(score_pair['a'])
        accumulated_scores_b.append(score_pair['b'])

        if processed < min_blocks_before_check:
            continue

        is_sufficient, _ = _check_sufficiency_with_verifier(
            accumulated_scores_a, accumulated_scores_b, verification_threshold
        )
        if is_sufficient:
            return True

    if min_blocks_before_check <= len(block_scores):
        # The last loop iteration already verified the full score lists and
        # found them insufficient; repeating the same check cannot change that.
        return False

    # Fewer blocks were scored than the probing budget, so the loop never
    # consulted the verifier: check once with everything we have.
    is_sufficient, _ = _check_sufficiency_with_verifier(
        accumulated_scores_a, accumulated_scores_b, verification_threshold
    )
    return is_sufficient

# -------- Main entry point --------
# Pipeline: parse CLI args -> load verifier into the module-level global ->
# build both inference engines -> pre-compute block scores once -> sweep every
# (verification threshold, probing ratio) pair -> write accuracies to CSV.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_a_dir", type=str, required=True)
    parser.add_argument("--model_b_dir", type=str, required=True)
    parser.add_argument("--verifier_model_path", type=str, required=True, help="Path to the verifier model (e.g., best_model.pt)")
    parser.add_argument("--verifier_model_type", type=str, default="DeepSet", choices=["RNN", "DeepSet"], help="Type of verifier model")
    # NOTE(review): this default (256) differs from DeepSetClassifier's own
    # default hidden_dim (512); it must match the checkpoint being loaded.
    parser.add_argument("--verifier_hidden_dim", type=int, default=256, help="Hidden dimension for verifier model")
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--csv_output_path", type=str, required=True, help="Path to save the CSV results")
    # parser.add_argument("--block_size", type=int, default=10, help="Currently unused.") # Kept if needed later
    parser.add_argument("--token_ratio", type=float, default=0.1, help="Theta: Ratio of tokens to sample within a block for model A.")
    
    # Arguments for lists of thresholds and probing_ratios
    parser.add_argument("--verification_thresholds", type=float, nargs='+', required=True, help="List of verification thresholds to test.")
    parser.add_argument("--probing_ratios", type=float, nargs='+', required=True, help="List of probing ratios (gamma) to test.")
    parser.add_argument("--max_samples", type=int, default=None, help="Process a maximum number of samples for quick testing (e.g., 100).")


    args = parser.parse_args()

    print(f"Device: {DEVICE}")

    # Initialize Verifier Model Globally
    # (this block runs at module scope, so the assignment below rebinds the
    # module-level `verifier_model_global` read by _check_sufficiency_with_verifier)
    print("🧠 Initializing Verifier Model...")
    if args.verifier_model_type.upper() == "RNN":
        # Add rnn_type to parser if you want to customize RNN type (GRU/LSTM)
        verifier_model_global = RNNClassifier(hidden_dim=args.verifier_hidden_dim, rnn_type='GRU').to(DEVICE)
    else: # Default to DeepSet
        verifier_model_global = DeepSetClassifier(hidden_dim=args.verifier_hidden_dim).to(DEVICE)
    
    # NOTE(review): any load failure (missing file, mismatched state dict, ...)
    # is only reported; the run then continues with randomly initialized
    # verifier weights, so the accuracies written below would not reflect a
    # trained verifier. Confirm this fallback is intended.
    try:
        verifier_model_global.load_state_dict(torch.load(args.verifier_model_path, map_location=DEVICE))
        print(f"Verifier model loaded successfully from '{args.verifier_model_path}'")
    except FileNotFoundError:
        print(f"Warning: Verifier model '{args.verifier_model_path}' not found. Using a randomly initialized verifier.")
    except Exception as e:
        print(f"Error loading verifier model: {e}. Using a randomly initialized verifier.")
    verifier_model_global.eval()


    print("🚀 Initializing Inference Engines...")
    infer_a = MatchingInference(args.model_a_dir)
    infer_b = MatchingInference(args.model_b_dir)
    print("✅ Inference Engines Ready.")

    # The data file is a JSON document that is sliced like a list below.
    with open(args.data_path, "r", encoding="utf-8") as fin:
        all_data = json.load(fin)

    if args.max_samples is not None and args.max_samples > 0:
        print(f"🔪 Using a subset of {args.max_samples} samples for processing.")
        data_subset = all_data[:args.max_samples]
    else:
        data_subset = all_data
    
    # --- Stage 1: Pre-computation of scores for all samples ---
    # Create a map for P to A if A is constant for a P, to pass to preprocessing.
    # For simplicity, assuming A is directly in each item.
    # fixed_answer_map = {item['P']: item['A'] for item in data_subset} # If P->A is many-to-one
    
    preprocessed_samples = preprocess_samples_for_efficient_validation(
        data_subset, infer_a, infer_b, args.token_ratio #, fixed_answer_map
    )

    results_for_csv = []

    print(f"\n⚙️ Starting multi-parameter validation loop for {len(args.verification_thresholds)} thresholds and {len(args.probing_ratios)} probing ratios...")
    
    # --- Stage 2: Iterate through parameter combinations and evaluate using pre-computed scores ---
    for v_thresh in tqdm(args.verification_thresholds, desc="Thresholds"):
        for p_ratio in tqdm(args.probing_ratios, desc="Probing Ratios", leave=False):
            correct_predictions = 0
            total_labeled_samples = 0
            
            for sample_data in preprocessed_samples:
                label = sample_data["label"]
                # Skip samples without labels for accuracy calculation, or count them differently
                if label is None:
                    continue 
                
                total_labeled_samples += 1
                
                # Get the prediction for this sample with current v_thresh and p_ratio
                pred_is_sufficient = run_validation_on_preprocessed(
                    sample_data, p_ratio, v_thresh
                )
                
                # The prediction is a bool, so this comparison assumes labels are
                # stored as booleans or 0/1 - TODO confirm against the data file.
                if pred_is_sufficient == label:
                    correct_predictions += 1
            
            # Guard against a data set with no labeled samples.
            accuracy = (correct_predictions / total_labeled_samples) if total_labeled_samples > 0 else 0.0
            
            results_for_csv.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "accuracy": f"{accuracy:.4f}", # Store as formatted string or float
                "correct_predictions": correct_predictions,
                "total_labeled_samples": total_labeled_samples
            })
            # print(f"  Params: VT={v_thresh:.2f}, PR={p_ratio:.2f} -> Acc: {accuracy:.2%}")


    # --- Stage 3: Output results to CSV ---
    # Column order follows the key order of the first result row.
    if results_for_csv:
        fieldnames = results_for_csv[0].keys()
        with open(args.csv_output_path, "w", newline='', encoding="utf-8") as fout_csv:
            writer = csv.DictWriter(fout_csv, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results_for_csv)
        print(f"\n✅ Accuracy results for different combinations saved to {args.csv_output_path}")
    else:
        print("\n⚠️ No results to save to CSV. Check data or parameters.")

    # Echo the same rows to stdout for a quick look without opening the CSV.
    print("\n----- Multi-Parameter Validation Summary -----")
    for res in results_for_csv:
        print(f"Threshold: {res['verification_threshold']}, Probing Ratio: {res['probing_ratio']}, Accuracy: {res['accuracy']} ({res['correct_predictions']}/{res['total_labeled_samples']})")